#
# Copyright 2018 Asylo authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for asylo.platform.arch.sgx.host_calls_generator.code_generator."""

from google.protobuf import text_format
from unittest import main
from unittest import TestCase
from asylo.platform.arch.sgx.host_calls_generator import code_generator


def _get_parameters_proto(host_calls_dictionary):
  """Get the FormalParameterProtos for the first host call in the dictionary."""
  return host_calls_dictionary['host_calls'][0].parameters


class CodeGeneratorTest(TestCase):
  """Tests for the host call code generator.

  The unit tests defined here check only the internals of the code generator and
  the functions exposed to the templating library. The correctness of the files
  and code generated by the code generator is tested through the build system
  and integration tests.
  """

  def test_host_call_invalid_textproto(self):
    """An unknown top-level field is rejected with a text_format.ParseError."""
    invalid_textproto = 'unknown_proto { unknown_field: "unknown_value" }'
    with self.assertRaises(text_format.ParseError):
      code_generator.get_host_calls_dictionary(invalid_textproto)

  def test_host_call_missing_required_field(self):
    """A parameter declared without a `type` raises ValueError."""
    missing_type_textproto = ('host_calls { name: "fsync" return_type: "int" '
                              'parameters { name: "fd" }}')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(missing_type_textproto)

  def test_host_call_no_parameters(self):
    """A host call with no parameters renders empty parameter/argument lists."""
    getpid_textproto = 'host_calls { name: "getpid" return_type: "pid_t" }'
    host_calls = code_generator.get_host_calls_dictionary(getpid_textproto)
    getpid_parameters = _get_parameters_proto(host_calls)
    self.assertEqual(
        '', code_generator.comma_separate_parameters(getpid_parameters))
    self.assertEqual('',
                     code_generator.comma_separate_arguments(getpid_parameters))

  def test_host_call_one_parameter(self):
    """A single parameter renders as `type name` and as a bare argument."""
    fsync_textproto = ('host_calls { name: "fsync" return_type: "int" '
                       'parameters { name: "fd" type: "int" }}')
    host_calls = code_generator.get_host_calls_dictionary(fsync_textproto)
    fsync_parameters = _get_parameters_proto(host_calls)
    self.assertEqual(
        'int fd', code_generator.comma_separate_parameters(fsync_parameters))
    self.assertEqual(
        'fd', code_generator.comma_separate_arguments(fsync_parameters))

  def test_host_call_multiple_parameters(self):
    """Multiple parameters are joined with ', ' in declaration order."""
    shutdown_textproto = ('host_calls { name: "shutdown" return_type: "int" '
                          'parameters { name: "sockfd" type: "int" } '
                          'parameters { name: "how" type: "int" }}')
    host_calls = code_generator.get_host_calls_dictionary(shutdown_textproto)
    shutdown_parameters = _get_parameters_proto(host_calls)
    self.assertEqual(
        'int sockfd, int how',
        code_generator.comma_separate_parameters(shutdown_parameters))
    self.assertEqual(
        'sockfd, how',
        code_generator.comma_separate_arguments(shutdown_parameters))

  def test_host_call_parameters_with_string_annotations(self):
    """IN/STRING pointer attributes appear only in the bridge parameters.

    The plain parameter and argument lists omit the attributes, while the
    bridge parameter list prefixes the annotated parameter with
    `[in, string]`.
    """
    # chown(path, owner, group): the path is an input C string.
    chown_textproto = ('host_calls { name: "chown" return_type: "int" '
                       'parameters { name: "path" type: "const char *" '
                       'pointer_attributes { attribute: IN } '
                       'pointer_attributes { attribute: STRING }} '
                       'parameters { name: "owner" type: "uint32_t" } '
                       'parameters { name: "group" type: "uint32_t" }}')
    host_calls = code_generator.get_host_calls_dictionary(chown_textproto)
    chown_parameters = _get_parameters_proto(host_calls)
    self.assertEqual(
        'const char * path, uint32_t owner, uint32_t group',
        code_generator.comma_separate_parameters(chown_parameters))
    self.assertEqual(
        'path, owner, group',
        code_generator.comma_separate_arguments(chown_parameters))
    self.assertEqual(
        '[in, string] const char * path, uint32_t owner, uint32_t group',
        code_generator.comma_separate_bridge_parameters(chown_parameters))

  def test_host_call_parameters_with_size_annotations(self):
    """SIZE attribute expressions render as `size=<expr>` in bridge params.

    Both a literal size ("100") and a size that names another parameter
    ("len") are covered; only the bridge parameter list shows the attributes.
    """
    # A read-like call with two output buffers: one with a fixed size and one
    # sized by the trailing `len` parameter.
    read_textproto = ('host_calls { name: "read" return_type: "int32_t" '
                      'parameters { name: "fd" type: "int" } '
                      'parameters { name: "buf1" type: "void *" '
                      'pointer_attributes { attribute: OUT } '
                      'pointer_attributes { attribute: SIZE '
                      'attribute_expression: "100" }} '
                      'parameters { name: "buf2" type: "void *" '
                      'pointer_attributes { attribute: OUT } '
                      'pointer_attributes { attribute: SIZE '
                      'attribute_expression: "len" }} '
                      'parameters { name: "len" type: "size_t" }}')
    host_calls = code_generator.get_host_calls_dictionary(read_textproto)
    read_parameters = _get_parameters_proto(host_calls)
    self.assertEqual(
        'int fd, void * buf1, void * buf2, size_t len',
        code_generator.comma_separate_parameters(read_parameters))
    self.assertEqual(
        'fd, buf1, buf2, len',
        code_generator.comma_separate_arguments(read_parameters))
    self.assertEqual(
        'int fd, [out, size=100] void * buf1, [out, size=len] void * buf2, '
        'size_t len',
        code_generator.comma_separate_bridge_parameters(read_parameters))

  def test_parameter_invalid_pointer_type(self):
    """Pointer attributes on a type with no '*' ('invalid_type') are rejected."""
    textproto = ('host_calls { name: "strlen" return_type: "int" '
                 'parameters { name: "s" type: "invalid_type" '
                 'pointer_attributes { attribute: STRING }}}')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(textproto)

  def test_parameter_missing_pointer_attributes(self):
    """A pointer-typed parameter with no pointer_attributes is rejected."""
    textproto = ('host_calls { name: "strlen" return_type: "int" '
                 'parameters { name: "s" type: "int *" }}')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(textproto)

  def test_parameter_duplicate_pointer_attributes(self):
    """Repeating the same pointer attribute (STRING twice) is rejected."""
    textproto = ('host_calls { name: "strlen" return_type: "int" '
                 'parameters { name: "s" type: "const char *" '
                 'pointer_attributes { attribute: STRING } '
                 'pointer_attributes { attribute: STRING }}}')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(textproto)

  def test_parameter_conflicting_copy_attributes(self):
    """Combining IN with USER_CHECK on one parameter is rejected."""
    textproto = ('host_calls { name: "strlen" return_type: "int" '
                 'parameters { name: "s" type: "const char *" '
                 'pointer_attributes { attribute: IN } '
                 'pointer_attributes { attribute: USER_CHECK }}}')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(textproto)

  def test_parameter_conflicting_length_attributes(self):
    """Combining STRING with SIZE on one parameter is rejected."""
    textproto = ('host_calls { name: "strlen" return_type: "int" '
                 'parameters { name: "s" type: "const char *" '
                 'pointer_attributes { attribute: STRING } '
                 'pointer_attributes { attribute: SIZE '
                 'attribute_expression: "10" }}}')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(textproto)

  def test_parameter_unnecessary_attribute_expression(self):
    """An attribute_expression on an IN attribute is rejected."""
    textproto = ('host_calls { name: "strlen" return_type: "int" '
                 'parameters { name: "s" type: "const char *" '
                 'pointer_attributes { attribute: IN '
                 'attribute_expression: "1" }}} ')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(textproto)

  def test_parameter_invalid_attribute_expression(self):
    """A SIZE expression naming the annotated parameter itself is rejected."""
    textproto = ('host_calls { name: "strlen" return_type: "int" '
                 'parameters { name: "s" type: "const char *" '
                 'pointer_attributes { attribute: SIZE '
                 'attribute_expression: "s" }}} ')
    with self.assertRaises(ValueError):
      code_generator.get_host_calls_dictionary(textproto)


if __name__ == '__main__':
  # Run this module's tests with unittest's command-line runner when the file
  # is executed directly (not when it is imported).
  main()
